{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92bc711b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "import pandas as pd\n",
    "import time\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "from processing.load_datasets import load_datasets\n",
    "from processing.configs import COLORS_4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1c13eb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Short codes of the datasets compared in this notebook. The PCA cell below looks\n",
    "# each one up via datasets[code], so load_datasets must return something indexable by code.\n",
    "datasets_names = [\n",
    "    \"rts\",\n",
    "    \"pdl\",\n",
    "    \"ioc\",\n",
    "    \"mjf\",\n",
    "]\n",
    "datasets = load_datasets(datasets_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3136a304",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit one PCA per dataset on standardized features and keep the explained-variance\n",
    "# ratio of every component. Large datasets are subsampled so the full SVD stays cheap.\n",
    "N_SAMPLE = 10_000  # maximum number of rows used for each PCA fit\n",
    "SEED = 42  # reset before each draw so every dataset's subsample is reproducible on its own\n",
    "\n",
    "explained_variances = {}\n",
    "for dataset in datasets_names:\n",
    "    # np.asarray keeps the row indexing below positional even if load_datasets returns a\n",
    "    # DataFrame (X[idx] on a DataFrame selects columns); plain ndarrays are returned unchanged.\n",
    "    X = np.asarray(datasets[dataset])\n",
    "\n",
    "    if X.shape[0] > N_SAMPLE:\n",
    "        np.random.seed(SEED)\n",
    "        sample_indices = np.random.choice(X.shape[0], N_SAMPLE, replace=False)\n",
    "        X = X[sample_indices]\n",
    "\n",
    "    X_scaled = StandardScaler().fit_transform(X)\n",
    "\n",
    "    # Keep all components so the full spectrum is available to the cells below\n",
    "    pca = PCA()\n",
    "    pca.fit(X_scaled)\n",
    "\n",
    "    explained_variances[dataset] = pca.explained_variance_ratio_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb13bbad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cumulative explained variance of the leading principal components, one curve per dataset.\n",
    "N_COMPONENTS_PLOT = 20  # number of leading components shown on the x-axis\n",
    "FIG_PATH = \"images/pca_explained_variance.png\"\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "for i, dataset in enumerate(datasets_names):\n",
    "    cum_evr = np.cumsum(explained_variances[dataset])[:N_COMPONENTS_PLOT]\n",
    "    ax.plot(np.arange(1, len(cum_evr) + 1), cum_evr, marker=\"x\", label=dataset.upper(), color=COLORS_4[i])\n",
    "\n",
    "# Axis styling does not depend on the dataset, so apply it once after all curves are drawn\n",
    "ax.set_xlabel(\"Number of principal components\", fontsize=12, fontweight=\"bold\")\n",
    "ax.set_ylabel(\"Cumulative explained variance ratio\", fontsize=12, fontweight=\"bold\")\n",
    "ax.set_ylim(0, 1.01)\n",
    "ax.set_xticks(np.arange(1, N_COMPONENTS_PLOT + 1))\n",
    "ax.grid(True, alpha=0.3)\n",
    "ax.legend()\n",
    "fig.tight_layout()\n",
    "\n",
    "# savefig does not create missing directories, so make sure the output folder exists\n",
    "os.makedirs(os.path.dirname(FIG_PATH), exist_ok=True)\n",
    "fig.savefig(FIG_PATH, dpi=300, bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6cca006",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Share of the total variance captured by the first k principal components of each dataset.\n",
    "for dataset in datasets_names:\n",
    "    ratios = explained_variances[dataset]\n",
    "    for k in (2, 10):\n",
    "        print(f\"{dataset.upper()}: {np.sum(ratios[:k]):.4f} explained variance with {k} components\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b4588a0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dr_eval",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
